{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/mrdbourke/pytorch-deep-learning/blob/main/video_notebooks/05_pytorch_going_modular_cell_mode_video.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1M5tsObEa17-"
      },
      "source": [
        "# 05. Going Modular: Part 1 (cell mode)\n",
        "\n",
        "This notebook is part 1/2 of section [05. Going Modular](https://www.learnpytorch.io/05_pytorch_going_modular/).\n",
        "\n",
        "For reference, the two parts are: \n",
        "1. [**05. Going Modular: Part 1 (cell mode)**](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/going_modular/05_pytorch_going_modular_cell_mode.ipynb) - this notebook is run as a traditional Jupyter Notebook/Google Colab notebook and is a condensed version of [notebook 04](https://www.learnpytorch.io/04_pytorch_custom_datasets/).\n",
        "2. [**05. Going Modular: Part 2 (script mode)**](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/going_modular/05_pytorch_going_modular_script_mode.ipynb) - this notebook is the same as number 1 but with added functionality to turn each of the major sections into Python scripts, such as, `data_setup.py` and `train.py`. \n",
        "\n",
        "Why two parts?\n",
        "\n",
        "Because sometimes the best way to learn something is to see how it *differs* from something else.\n",
        "\n",
        "If you run each notebook side-by-side you'll see how they differ and that's where the key learnings are."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "szBKtTCTkKZj"
      },
      "source": [
        "## What is cell mode?\n",
        "\n",
        "A cell mode notebook is a regular notebook run exactly how we've been running them through the course.\n",
        "\n",
        "Some cells contain text and others contain code."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-uLD7A-4kKZj"
      },
      "source": [
        "## What's the difference between this notebook (Part 1) and the script mode notebook (Part 2)?\n",
        "\n",
        "This notebook, 05. PyTorch Going Modular: Part 1 (cell mode), runs a cleaned up version of the most useful code from section [04. PyTorch Custom Datasets](https://www.learnpytorch.io/04_pytorch_custom_datasets/).\n",
        "\n",
        "Running this notebook end-to-end will result in recreating the image classification model we built in notebook 04 (TinyVGG) trained on images of pizza, steak and sushi.\n",
        "\n",
        "The main difference between this notebook (Part 1) and Part 2 is that each section in Part 2 (script mode) has an extra subsection (e.g. 2.1, 3.1, 4.1) for turning cell code into script code."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B33tS78_kKZk"
      },
      "source": [
        "## Where can you get help?\n",
        "\n",
        "You can find the book version of this section [05. PyTorch Going Modular on learnpytorch.io](https://www.learnpytorch.io/05_pytorch_going_modular/).\n",
        "\n",
        "The rest of the materials for this course [are available on GitHub](https://github.com/mrdbourke/pytorch-deep-learning).\n",
        "\n",
        "If you run into trouble, you can ask a question on the course [GitHub Discussions page](https://github.com/mrdbourke/pytorch-deep-learning/discussions).\n",
        "\n",
        "And of course, there's the [PyTorch documentation](https://pytorch.org/docs/stable/index.html) and [PyTorch developer forums](https://discuss.pytorch.org/), a very helpful place for all things PyTorch. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I_0BPjLcawE-"
      },
      "source": [
        "## 0. Running a notebook in cell mode\n",
        "\n",
        "As discussed, we're going to be running this notebook normally.\n",
        "\n",
        "One cell at a time.\n",
        "\n",
        "The code is from notebook 04, however, it has been condensed down to its core functionality."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1XgXFr97bCsI"
      },
      "source": [
        "## 1. Get data\n",
        "\n",
        "We're going to start by downloading the same data we used in [notebook 04](https://www.learnpytorch.io/04_pytorch_custom_datasets/#1-get-data), the `pizza_steak_sushi` dataset with images of pizza, steak and sushi."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wl_jfPzVbGDS",
        "outputId": "b6adca84-062b-4166-ceba-950067997201"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Did not find data/pizza_steak_sushi directory, creating one...\n",
            "Downloading pizza, steak, sushi data...\n",
            "Unzipping pizza, steak, sushi data...\n"
          ]
        }
      ],
      "source": [
        "import os\n",
        "import zipfile\n",
        "\n",
        "from pathlib import Path\n",
        "\n",
        "import requests\n",
        "\n",
        "# Setup path to data folder\n",
        "data_path = Path(\"data/\")\n",
        "image_path = data_path / \"pizza_steak_sushi\"\n",
        "\n",
        "# If the image folder doesn't exist, download it and prepare it... \n",
        "if image_path.is_dir():\n",
        "    print(f\"{image_path} directory exists.\")\n",
        "else:\n",
        "    print(f\"Did not find {image_path} directory, creating one...\")\n",
        "    image_path.mkdir(parents=True, exist_ok=True)\n",
        "    \n",
        "# Download pizza, steak, sushi data\n",
        "with open(data_path / \"pizza_steak_sushi.zip\", \"wb\") as f:\n",
        "    request = requests.get(\"https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip\")\n",
        "    print(\"Downloading pizza, steak, sushi data...\")\n",
        "    f.write(request.content)\n",
        "\n",
        "# Unzip pizza, steak, sushi data\n",
        "with zipfile.ZipFile(data_path / \"pizza_steak_sushi.zip\", \"r\") as zip_ref:\n",
        "    print(\"Unzipping pizza, steak, sushi data...\") \n",
        "    zip_ref.extractall(image_path)\n",
        "    \n",
        "# Remove zip file\n",
        "os.remove(data_path / \"pizza_steak_sushi.zip\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8qYMdiBMbde0",
        "outputId": "0aa26939-236e-401c-d147-2898698bf9e5"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(PosixPath('data/pizza_steak_sushi/train'),\n",
              " PosixPath('data/pizza_steak_sushi/test'))"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ],
      "source": [
        "# Setup train and testing paths\n",
        "train_dir = image_path / \"train\"\n",
        "test_dir = image_path / \"test\"\n",
        "\n",
        "train_dir, test_dir"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eYL6NynybU9Z"
      },
      "source": [
        "## 2. Create Datasets and DataLoaders\n",
        "\n",
        "Now we'll turn the image dataset into PyTorch `Dataset`'s and `DataLoader`'s. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DaOJOd-QbOQk",
        "outputId": "a6fc9ad5-763d-48c5-aa8b-3c7d381b9740"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Train data:\n",
            "Dataset ImageFolder\n",
            "    Number of datapoints: 225\n",
            "    Root location: data/pizza_steak_sushi/train\n",
            "    StandardTransform\n",
            "Transform: Compose(\n",
            "               Resize(size=(64, 64), interpolation=bilinear, max_size=None, antialias=None)\n",
            "               ToTensor()\n",
            "           )\n",
            "Test data:\n",
            "Dataset ImageFolder\n",
            "    Number of datapoints: 75\n",
            "    Root location: data/pizza_steak_sushi/test\n",
            "    StandardTransform\n",
            "Transform: Compose(\n",
            "               Resize(size=(64, 64), interpolation=bilinear, max_size=None, antialias=None)\n",
            "               ToTensor()\n",
            "           )\n"
          ]
        }
      ],
      "source": [
        "from torchvision import datasets, transforms\n",
        "\n",
        "# Create simple transform\n",
        "data_transform = transforms.Compose([ \n",
        "    transforms.Resize((64, 64)),\n",
        "    transforms.ToTensor(),\n",
        "])\n",
        "\n",
        "# Use ImageFolder to create dataset(s)\n",
        "train_data = datasets.ImageFolder(root=train_dir, # target folder of images\n",
        "                                  transform=data_transform, # transforms to perform on data (images)\n",
        "                                  target_transform=None) # transforms to perform on labels (if necessary)\n",
        "\n",
        "test_data = datasets.ImageFolder(root=test_dir, \n",
        "                                 transform=data_transform)\n",
        "\n",
        "print(f\"Train data:\\n{train_data}\\nTest data:\\n{test_data}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eEln0VFmbfyY",
        "outputId": "ee9c518c-f61f-47cd-ec3b-310f1a19bc3e"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['pizza', 'steak', 'sushi']"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "source": [
        "# Get class names as a list\n",
        "class_names = train_data.classes\n",
        "class_names"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IMGTFVr8bgwE",
        "outputId": "9269f0bc-9154-44b5-9ffd-101d7c46ea24"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'pizza': 0, 'steak': 1, 'sushi': 2}"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ],
      "source": [
        "# Can also get class names as a dict\n",
        "class_dict = train_data.class_to_idx\n",
        "class_dict"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6a8IW-IYbhGX",
        "outputId": "8c2eebfd-1258-42be-8c65-841094a9792d"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(225, 75)"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ],
      "source": [
        "# Check the lengths\n",
        "len(train_data), len(test_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "f1UxnvvmblUj",
        "outputId": "d127b346-8611-46fe-f9df-8f05e59b4d48"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(<torch.utils.data.dataloader.DataLoader at 0x7fed03d16b50>,\n",
              " <torch.utils.data.dataloader.DataLoader at 0x7fed03d16590>)"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ],
      "source": [
        "# Turn train and test Datasets into DataLoaders\n",
        "from torch.utils.data import DataLoader\n",
        "train_dataloader = DataLoader(dataset=train_data, \n",
        "                              batch_size=1, # how many samples per batch?\n",
        "                              num_workers=1, # how many subprocesses to use for data loading? (higher = more)\n",
        "                              shuffle=True) # shuffle the data?\n",
        "\n",
        "test_dataloader = DataLoader(dataset=test_data, \n",
        "                             batch_size=1, \n",
        "                             num_workers=1, \n",
        "                             shuffle=False) # don't usually need to shuffle testing data\n",
        "\n",
        "train_dataloader, test_dataloader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JJKRqXNGbnI5",
        "outputId": "9a47cee6-2fc7-4b84-b035-0f71130f23df"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Image shape: torch.Size([1, 3, 64, 64]) -> [batch_size, color_channels, height, width]\n",
            "Label shape: torch.Size([1])\n"
          ]
        }
      ],
      "source": [
        "# Check out single image size/shape\n",
        "img, label = next(iter(train_dataloader))\n",
        "\n",
        "# Batch size will now be 1, try changing the batch_size parameter above and see what happens\n",
        "print(f\"Image shape: {img.shape} -> [batch_size, color_channels, height, width]\")\n",
        "print(f\"Label shape: {label.shape}\")"
      ]
    },
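    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the `batch_size` suggestion above concrete, here is a minimal sketch that swaps the real images for random tensors of the same shape. This is an assumption for illustration only, and the names `fake_images`, `fake_labels`, `fake_dataloader` and `batch_shapes` are hypothetical. It shows how 225 samples would be split up with `batch_size=32`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch.utils.data import DataLoader, TensorDataset\n",
        "\n",
        "# Hypothetical stand-in data: 225 random \"images\" shaped like the resized training images\n",
        "fake_images = torch.rand(225, 3, 64, 64)\n",
        "fake_labels = torch.randint(low=0, high=3, size=(225,))\n",
        "fake_dataloader = DataLoader(TensorDataset(fake_images, fake_labels),\n",
        "                             batch_size=32,\n",
        "                             shuffle=False)\n",
        "\n",
        "# 225 samples / 32 per batch = 7 full batches + 1 final batch holding the single leftover sample\n",
        "batch_shapes = [tuple(images.shape) for images, _ in fake_dataloader]\n",
        "print(f\"Number of batches: {len(fake_dataloader)}\")\n",
        "print(f\"First batch shape: {batch_shapes[0]}\")\n",
        "print(f\"Last batch shape: {batch_shapes[-1]}\")"
      ]
    },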
    {
      "cell_type": "markdown",
      "source": [
        "### 2.1 Create Datasets and DataLoaders (script mode) \n",
        "\n",
        "Let's use the Jupyter magic function to create a `.py` file for creating DataLoaders. \n",
        "\n",
        "We can save a code cell's contents to a file using the Jupyter magic `%%writefile filename` - https://ipython.readthedocs.io/en/stable/interactive/magics.html#cellmagic-writefile "
      ],
      "metadata": {
        "id": "uwVODowqxUba"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Create a directory going_modular scripts\n",
        "import os\n",
        "os.makedirs(\"going_modular\")"
      ],
      "metadata": {
        "id": "aIeUk5u5ytlA"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile going_modular/data_setup.py\n",
        "\"\"\"\n",
        "Contains functionality for creating PyTorch DataLoader's for\n",
        "image classification data.\n",
        "\"\"\"\n",
        "import os\n",
        "\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import DataLoader\n",
        "\n",
        "NUM_WORKERS = os.cpu_count()\n",
        "\n",
        "def create_dataloaders(\n",
        "  train_dir: str,\n",
        "  test_dir: str,\n",
        "  transform: transforms.Compose,\n",
        "  batch_size: int,\n",
        "  num_workers: int=NUM_WORKERS  \n",
        "):\n",
        "  \"\"\"Creates training and testing DataLoaders.\n",
        "\n",
        "  Takes in a training directory and testing directroy path and turns them into \n",
        "  PyTorch Datasets and then into PyTorch DataLoaders.\n",
        "\n",
        "  Args:\n",
        "    train_dir: Path to training directory.\n",
        "    test_dir: Path to testing directory.\n",
        "    transform: torchvision transforms to perform on training and testing data.\n",
        "    batch_size: Number of samples per batch in each of the DataLoaders.\n",
        "    num_workers: An integer for number of workers per DataLoader.\n",
        "\n",
        "  Returns:\n",
        "    A tuple of (train_dataloader, test_dataloader, class_names).\n",
        "    Where class_names is a list of the target classes.\n",
        "    Example usage:\n",
        "      train_dataloader, test_dataloader, class_names = create_dataloaders(train_dir=path/to/train_dir,\n",
        "        test_dir=path/to/test_dir,\n",
        "        transform=some_transform,\n",
        "        batch_size=32,\n",
        "        num_workers=4)\n",
        "  \"\"\"\n",
        "  # Use ImageFolder to create datasets(s)\n",
        "  train_data = datasets.ImageFolder(train_dir, transform=transform)\n",
        "  test_data = datasets.ImageFolder(test_dir, transform=transform)\n",
        "\n",
        "  # Get class names\n",
        "  class_names = train_data.classes\n",
        "\n",
        "  # Turn images into DataLoaders\n",
        "  train_dataloader = DataLoader(\n",
        "      train_data,\n",
        "      batch_size=batch_size,\n",
        "      shuffle=True,\n",
        "      num_workers=num_workers,\n",
        "      pin_memory=True # for more on pin memory, see the PyTorch docs: https://pytorch.org/docs/stable/data.html \n",
        "  )\n",
        "\n",
        "  test_dataloader = DataLoader(\n",
        "      test_data,\n",
        "      batch_size=batch_size,\n",
        "      shuffle=False,\n",
        "      num_workers=num_workers,\n",
        "      pin_memory=True\n",
        "  )\n",
        "\n",
        "  return train_dataloader, test_dataloader, class_names"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7q2bWHYrxa67",
        "outputId": "790ab214-097d-439f-9821-99e8287a1bf8"
      },
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Overwriting going_modular/data_setup.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from going_modular import data_setup\n",
        "\n",
        "train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(train_dir=train_dir,\n",
        "                                                                               test_dir=test_dir,\n",
        "                                                                               transform=data_transform,\n",
        "                                                                               batch_size=32)\n",
        "train_dataloader, test_dataloader, class_names"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "71NsJ8df2lsv",
        "outputId": "e240aafd-e101-40a0-971c-41a6a33aeb5c"
      },
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(<torch.utils.data.dataloader.DataLoader at 0x7fec47a09490>,\n",
              " <torch.utils.data.dataloader.DataLoader at 0x7fec47a09610>,\n",
              " ['pizza', 'steak', 'sushi'])"
            ]
          },
          "metadata": {},
          "execution_count": 19
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gAQLp1dCbrY0"
      },
      "source": [
        "## 3. Making a model (TinyVGG)\n",
        "\n",
        "We're going to use the same model we used in notebook 04: TinyVGG from the CNN Explainer website.\n",
        "\n",
        "The only change here from notebook 04 is that a docstring has been added using [Google's Style Guide for Python](https://google.github.io/styleguide/pyguide.html#384-classes). "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "Z6i2xWcwbvom"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "from torch import nn\n",
        "\n",
        "class TinyVGG(nn.Module):\n",
        "  \"\"\"Creates the TinyVGG architecture.\n",
        "\n",
        "  Replicates the TinyVGG architecture from the CNN explainer website in PyTorch.\n",
        "  See the original architecture here: https://poloclub.github.io/cnn-explainer/\n",
        "  \n",
        "  Args:\n",
        "    input_shape: An integer indicating number of input channels.\n",
        "    hidden_units: An integer indicating number of hidden units between layers.\n",
        "    output_shape: An integer indicating number of output units.\n",
        "  \"\"\"\n",
        "  def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:\n",
        "      super().__init__()\n",
        "      self.conv_block_1 = nn.Sequential(\n",
        "          nn.Conv2d(in_channels=input_shape, \n",
        "                    out_channels=hidden_units, \n",
        "                    kernel_size=3, # how big is the square that's going over the image?\n",
        "                    stride=1, # default\n",
        "                    padding=0), # options = \"valid\" (no padding) or \"same\" (output has same shape as input) or int for specific number \n",
        "          nn.ReLU(),\n",
        "          nn.Conv2d(in_channels=hidden_units, \n",
        "                    out_channels=hidden_units,\n",
        "                    kernel_size=3,\n",
        "                    stride=1,\n",
        "                    padding=0),\n",
        "          nn.ReLU(),\n",
        "          nn.MaxPool2d(kernel_size=2,\n",
        "                        stride=2) # default stride value is same as kernel_size\n",
        "      )\n",
        "      self.conv_block_2 = nn.Sequential(\n",
        "          nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),\n",
        "          nn.ReLU(),\n",
        "          nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),\n",
        "          nn.ReLU(),\n",
        "          nn.MaxPool2d(2)\n",
        "      )\n",
        "      self.classifier = nn.Sequential(\n",
        "          nn.Flatten(),\n",
        "          # Where did this in_features shape come from? \n",
        "          # It's because each layer of our network compresses and changes the shape of our inputs data.\n",
        "          nn.Linear(in_features=hidden_units*13*13,\n",
        "                    out_features=output_shape)\n",
        "      )\n",
        "    \n",
        "  def forward(self, x: torch.Tensor):\n",
        "      x = self.conv_block_1(x)\n",
        "      x = self.conv_block_2(x)\n",
        "      x = self.classifier(x)\n",
        "      return x\n",
        "      # return self.classifier(self.block_2(self.block_1(x))) # <- leverage the benefits of operator fusion"
      ]
    },
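    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `hidden_units*13*13` value in the classifier can be checked by hand. The sketch below is not part of the original notebook. It assumes 64x64 input images (as set by `transforms.Resize((64, 64))`) and uses a hypothetical helper, `conv_output_size`, which applies the standard output-size formula for convolution and max-pooling layers with dilation of 1."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def conv_output_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:\n",
        "    \"\"\"Returns the output height/width of a conv or max-pool layer (dilation=1).\"\"\"\n",
        "    return (size + 2 * padding - kernel_size) // stride + 1\n",
        "\n",
        "feature_map_size = 64  # assumed input height/width\n",
        "for block_number in (1, 2):  # conv_block_1 and conv_block_2 share the same layer layout\n",
        "    feature_map_size = conv_output_size(feature_map_size, kernel_size=3)  # first Conv2d\n",
        "    feature_map_size = conv_output_size(feature_map_size, kernel_size=3)  # second Conv2d\n",
        "    feature_map_size = conv_output_size(feature_map_size, kernel_size=2, stride=2)  # MaxPool2d\n",
        "    print(f\"Height/width after conv_block_{block_number}: {feature_map_size}\")\n",
        "\n",
        "example_hidden_units = 10  # same value used to instantiate the model in this notebook\n",
        "flattened_features = example_hidden_units * feature_map_size * feature_map_size\n",
        "print(f\"in_features for the Linear layer: {flattened_features}\")"
      ]
    },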
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tXOOAEHHc6-L",
        "outputId": "91f3fd8a-f419-4bfc-983b-71633884476a"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "TinyVGG(\n",
              "  (conv_block_1): Sequential(\n",
              "    (0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))\n",
              "    (1): ReLU()\n",
              "    (2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
              "    (3): ReLU()\n",
              "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
              "  )\n",
              "  (conv_block_2): Sequential(\n",
              "    (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
              "    (1): ReLU()\n",
              "    (2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
              "    (3): ReLU()\n",
              "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
              "  )\n",
              "  (classifier): Sequential(\n",
              "    (0): Flatten(start_dim=1, end_dim=-1)\n",
              "    (1): Linear(in_features=1690, out_features=3, bias=True)\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 21
        }
      ],
      "source": [
        "import torch\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Instantiate an instance of the model\n",
        "torch.manual_seed(42)\n",
        "model_0 = TinyVGG(input_shape=3, # number of color channels (3 for RGB) \n",
        "                  hidden_units=10, \n",
        "                  output_shape=len(train_data.classes)).to(device)\n",
        "model_0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C1CwvPiMcEfZ"
      },
      "source": [
        "To test our model let's do a single forward pass (pass a sample batch from the training set through our model)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "__p5G2ArcFko",
        "outputId": "c8de0391-df9d-40b9-d34e-7c991aa419a9"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Single image shape: torch.Size([1, 3, 64, 64])\n",
            "\n",
            "Output logits:\n",
            "tensor([[ 0.0208, -0.0019,  0.0095]], device='cuda:0')\n",
            "\n",
            "Output prediction probabilities:\n",
            "tensor([[0.3371, 0.3295, 0.3333]], device='cuda:0')\n",
            "\n",
            "Output prediction label:\n",
            "tensor([0], device='cuda:0')\n",
            "\n",
            "Actual label:\n",
            "0\n"
          ]
        }
      ],
      "source": [
        "# 1. Get a batch of images and labels from the DataLoader\n",
        "img_batch, label_batch = next(iter(train_dataloader))\n",
        "\n",
        "# 2. Get a single image from the batch and unsqueeze the image so its shape fits the model\n",
        "img_single, label_single = img_batch[0].unsqueeze(dim=0), label_batch[0]\n",
        "print(f\"Single image shape: {img_single.shape}\\n\")\n",
        "\n",
        "# 3. Perform a forward pass on a single image\n",
        "model_0.eval()\n",
        "with torch.inference_mode():\n",
        "    pred = model_0(img_single.to(device))\n",
        "    \n",
        "# 4. Print out what's happening and convert model logits -> pred probs -> pred label\n",
        "print(f\"Output logits:\\n{pred}\\n\")\n",
        "print(f\"Output prediction probabilities:\\n{torch.softmax(pred, dim=1)}\\n\")\n",
        "print(f\"Output prediction label:\\n{torch.argmax(torch.softmax(pred, dim=1), dim=1)}\\n\")\n",
        "print(f\"Actual label:\\n{label_single}\")"
      ]
    },
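    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The logits -> prediction probabilities -> prediction label conversion can also be tried on its own. The sketch below copies the logits from the example output above (your values will likely differ between runs and devices) and uses hypothetical variable names. It also illustrates that softmax keeps the ordering of the logits, so taking `argmax` of the raw logits gives the same label as taking it after softmax."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Logits copied from the example output above (illustrative values only)\n",
        "example_logits = torch.tensor([[0.0208, -0.0019, 0.0095]])\n",
        "\n",
        "# Softmax across the class dimension turns logits into probabilities that sum to 1\n",
        "example_probs = torch.softmax(example_logits, dim=1)\n",
        "\n",
        "# Softmax preserves ordering, so both argmax calls select the same class index\n",
        "label_from_logits = torch.argmax(example_logits, dim=1)\n",
        "label_from_probs = torch.argmax(example_probs, dim=1)\n",
        "\n",
        "print(f\"Prediction probabilities: {example_probs}\")\n",
        "print(f\"Label from logits: {label_from_logits} | Label from probabilities: {label_from_probs}\")"
      ]
    },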
    {
      "cell_type": "markdown",
      "source": [
        "### 3.1. Making a model (TinyVGG) with a script (`model_builder.py`)\n",
        "\n",
        "Let's turn our model building code into a Python script we can import."
      ],
      "metadata": {
        "id": "vzReGPag4pe4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile going_modular/model_builder.py \n",
        "\"\"\"\n",
        "Contains PyTorch model code to instantiate a TinyVGG model from the CNN Explainer website.\n",
        "\"\"\"\n",
        "import torch\n",
        "\n",
        "from torch import nn\n",
        "\n",
        "class TinyVGG(nn.Module):\n",
        "  \"\"\"Creates the TinyVGG architecture.\n",
        "\n",
        "  Replicates the TinyVGG architecture from the CNN explainer website in PyTorch.\n",
        "  See the original architecture here: https://poloclub.github.io/cnn-explainer/\n",
        "  \n",
        "  Args:\n",
        "    input_shape: An integer indicating number of input channels.\n",
        "    hidden_units: An integer indicating number of hidden units between layers.\n",
        "    output_shape: An integer indicating number of output units.\n",
        "  \"\"\"\n",
        "  def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:\n",
        "      super().__init__()\n",
        "      self.conv_block_1 = nn.Sequential(\n",
        "          nn.Conv2d(in_channels=input_shape, \n",
        "                    out_channels=hidden_units, \n",
        "                    kernel_size=3, # how big is the square that's going over the image?\n",
        "                    stride=1, # default\n",
        "                    padding=0), # options = \"valid\" (no padding) or \"same\" (output has same shape as input) or int for specific number \n",
        "          nn.ReLU(),\n",
        "          nn.Conv2d(in_channels=hidden_units, \n",
        "                    out_channels=hidden_units,\n",
        "                    kernel_size=3,\n",
        "                    stride=1,\n",
        "                    padding=0),\n",
        "          nn.ReLU(),\n",
        "          nn.MaxPool2d(kernel_size=2,\n",
        "                        stride=2) # default stride value is same as kernel_size\n",
        "      )\n",
        "      self.conv_block_2 = nn.Sequential(\n",
        "          nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),\n",
        "          nn.ReLU(),\n",
        "          nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),\n",
        "          nn.ReLU(),\n",
        "          nn.MaxPool2d(2)\n",
        "      )\n",
        "      self.classifier = nn.Sequential(\n",
        "          nn.Flatten(),\n",
        "          # Where did this in_features shape come from? \n",
        "          # It's because each layer of our network compresses and changes the shape of our inputs data.\n",
        "          nn.Linear(in_features=hidden_units*13*13,\n",
        "                    out_features=output_shape)\n",
        "      )\n",
        "    \n",
        "  def forward(self, x: torch.Tensor):\n",
        "      x = self.conv_block_1(x)\n",
        "      x = self.conv_block_2(x)\n",
        "      x = self.classifier(x)\n",
        "      return x\n",
        "      # return self.classifier(self.block_2(self.block_1(x))) # <- leverage the benefits of operator fusion"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uEKPzMvx4zxd",
        "outputId": "f106f0f3-785a-49f5-c3fb-8fe4f8b14c3b"
      },
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Overwriting going_modular/model_builder.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from going_modular import model_builder"
      ],
      "metadata": {
        "id": "S9Tp9O7E5sAG"
      },
      "execution_count": 26,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "from going_modular import model_builder\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Instantiate a model from the model_builder.py script\n",
        "torch.manual_seed(42)\n",
        "model_1 = model_builder.TinyVGG(input_shape=3,\n",
        "                                hidden_units=10,\n",
        "                                output_shape=len(class_names)).to(device)\n",
        "model_1"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "li_BWBjP5WuV",
        "outputId": "c64193b0-95d3-43bb-a4e3-8d1c88ad3f27"
      },
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "TinyVGG(\n",
              "  (conv_block_1): Sequential(\n",
              "    (0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))\n",
              "    (1): ReLU()\n",
              "    (2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
              "    (3): ReLU()\n",
              "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
              "  )\n",
              "  (conv_block_2): Sequential(\n",
              "    (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
              "    (1): ReLU()\n",
              "    (2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
              "    (3): ReLU()\n",
              "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
              "  )\n",
              "  (classifier): Sequential(\n",
              "    (0): Flatten(start_dim=1, end_dim=-1)\n",
              "    (1): Linear(in_features=1690, out_features=3, bias=True)\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# 1. Get a batch of images and labels from the DataLoader\n",
        "img_batch, label_batch = next(iter(train_dataloader))\n",
        "\n",
        "# 2. Get a single image from the batch and unsqueeze the image so its shape fits the model\n",
        "img_single, label_single = img_batch[0].unsqueeze(dim=0), label_batch[0]\n",
        "print(f\"Single image shape: {img_single.shape}\\n\")\n",
        "\n",
        "# 3. Perform a forward pass on a single image\n",
        "model_1.eval()\n",
        "with torch.inference_mode():\n",
        "    pred = model_1(img_single.to(device))\n",
        "    \n",
        "# 4. Print out what's happening and convert model logits -> pred probs -> pred label\n",
        "print(f\"Output logits:\\n{pred}\\n\")\n",
        "print(f\"Output prediction probabilities:\\n{torch.softmax(pred, dim=1)}\\n\")\n",
        "print(f\"Output prediction label:\\n{torch.argmax(torch.softmax(pred, dim=1), dim=1)}\\n\")\n",
        "print(f\"Actual label:\\n{label_single}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "x5VFXm3h5-qt",
        "outputId": "23789285-dbaa-485f-d3bd-6dcdbdf8dcdd"
      },
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Single image shape: torch.Size([1, 3, 64, 64])\n",
            "\n",
            "Output logits:\n",
            "tensor([[ 0.0208, -0.0019,  0.0095]], device='cuda:0')\n",
            "\n",
            "Output prediction probabilities:\n",
            "tensor([[0.3371, 0.3295, 0.3333]], device='cuda:0')\n",
            "\n",
            "Output prediction label:\n",
            "tensor([0], device='cuda:0')\n",
            "\n",
            "Actual label:\n",
            "0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8OB4xIsXcHrp"
      },
      "source": [
        "## 4. Creating `train_step()` and `test_step()` functions and `train()` to combine them  \n",
        "\n",
        "Rather than writing them again, we can reuse the `train_step()` and `test_step()` functions from [notebook 04](https://www.learnpytorch.io/04_pytorch_custom_datasets/#75-create-train-test-loop-functions).\n",
        "\n",
        "The same goes for the `train()` function we created.\n",
        "\n",
        "The only difference here is that these functions have had docstrings added to them in [Google's Python Functions and Methods Style Guide](https://google.github.io/styleguide/pyguide.html#383-functions-and-methods).\n",
        "\n",
        "Let's start by making `train_step()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "id": "t_0zRYZecJZ6"
      },
      "outputs": [],
      "source": [
        "from typing import Tuple\n",
        "\n",
        "def train_step(model: torch.nn.Module, \n",
        "               dataloader: torch.utils.data.DataLoader, \n",
        "               loss_fn: torch.nn.Module, \n",
        "               optimizer: torch.optim.Optimizer,\n",
        "               device: torch.device) -> Tuple[float, float]:\n",
        "  \"\"\"Trains a PyTorch model for a single epoch.\n",
        "\n",
        "  Turns a target PyTorch model to training mode and then\n",
        "  runs through all of the required training steps (forward\n",
        "  pass, loss calculation, optimizer step).\n",
        "\n",
        "  Args:\n",
        "    model: A PyTorch model to be trained.\n",
        "    dataloader: A DataLoader instance for the model to be trained on.\n",
        "    loss_fn: A PyTorch loss function to minimize.\n",
        "    optimizer: A PyTorch optimizer to help minimize the loss function.\n",
        "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
        "\n",
        "  Returns:\n",
        "    A tuple of training loss and training accuracy metrics.\n",
        "    In the form (train_loss, train_accuracy). For example:\n",
        "    \n",
        "    (0.1112, 0.8743)\n",
        "  \"\"\"\n",
        "  # Put model in train mode\n",
        "  model.train()\n",
        "  \n",
        "  # Setup train loss and train accuracy values\n",
        "  train_loss, train_acc = 0, 0\n",
        "  \n",
        "  # Loop through data loader data batches\n",
        "  for batch, (X, y) in enumerate(dataloader):\n",
        "      # Send data to target device\n",
        "      X, y = X.to(device), y.to(device)\n",
        "\n",
        "      # 1. Forward pass\n",
        "      y_pred = model(X)\n",
        "\n",
        "      # 2. Calculate  and accumulate loss\n",
        "      loss = loss_fn(y_pred, y)\n",
        "      train_loss += loss.item() \n",
        "\n",
        "      # 3. Optimizer zero grad\n",
        "      optimizer.zero_grad()\n",
        "\n",
        "      # 4. Loss backward\n",
        "      loss.backward()\n",
        "\n",
        "      # 5. Optimizer step\n",
        "      optimizer.step()\n",
        "\n",
        "      # Calculate and accumulate accuracy metric across all batches\n",
        "      y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)\n",
        "      train_acc += (y_pred_class == y).sum().item()/len(y_pred)\n",
        "\n",
        "  # Adjust metrics to get average loss and accuracy per batch \n",
        "  train_loss = train_loss / len(dataloader)\n",
        "  train_acc = train_acc / len(dataloader)\n",
        "  return train_loss, train_acc"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7JQvCWvr-db-"
      },
      "source": [
        "Now we'll do `test_step()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "id": "sr_9AspYcKkW"
      },
      "outputs": [],
      "source": [
        "def test_step(model: torch.nn.Module, \n",
        "              dataloader: torch.utils.data.DataLoader, \n",
        "              loss_fn: torch.nn.Module,\n",
        "              device: torch.device) -> Tuple[float, float]:\n",
        "  \"\"\"Tests a PyTorch model for a single epoch.\n",
        "\n",
        "  Turns a target PyTorch model to \"eval\" mode and then performs\n",
        "  a forward pass on a testing dataset.\n",
        "\n",
        "  Args:\n",
        "    model: A PyTorch model to be tested.\n",
        "    dataloader: A DataLoader instance for the model to be tested on.\n",
        "    loss_fn: A PyTorch loss function to calculate loss on the test data.\n",
        "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
        "\n",
        "  Returns:\n",
        "    A tuple of testing loss and testing accuracy metrics.\n",
        "    In the form (test_loss, test_accuracy). For example:\n",
        "    \n",
        "    (0.0223, 0.8985)\n",
        "  \"\"\"\n",
        "  # Put model in eval mode\n",
        "  model.eval() \n",
        "  \n",
        "  # Setup test loss and test accuracy values\n",
        "  test_loss, test_acc = 0, 0\n",
        "  \n",
        "  # Turn on inference context manager\n",
        "  with torch.inference_mode():\n",
        "      # Loop through DataLoader batches\n",
        "      for batch, (X, y) in enumerate(dataloader):\n",
        "          # Send data to target device\n",
        "          X, y = X.to(device), y.to(device)\n",
        "  \n",
        "          # 1. Forward pass\n",
        "          test_pred_logits = model(X)\n",
        "\n",
        "          # 2. Calculate and accumulate loss\n",
        "          loss = loss_fn(test_pred_logits, y)\n",
        "          test_loss += loss.item()\n",
        "          \n",
        "          # Calculate and accumulate accuracy\n",
        "          test_pred_labels = test_pred_logits.argmax(dim=1)\n",
        "          test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))\n",
        "          \n",
        "  # Adjust metrics to get average loss and accuracy per batch \n",
        "  test_loss = test_loss / len(dataloader)\n",
        "  test_acc = test_acc / len(dataloader)\n",
        "  return test_loss, test_acc"
      ]
    },
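    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Both `train_step()` and `test_step()` average the per-batch accuracy by dividing by `len(dataloader)`. When the final batch is smaller than the others, this can differ slightly from the accuracy computed over all samples. A minimal sketch with made-up batch results (the numbers and variable names are illustrative only):\n",
        "\n",
        "```python\n",
        "# (correct predictions, batch size) for each batch; the last batch is smaller\n",
        "batches = [(24, 32), (20, 32), (1, 4)]\n",
        "\n",
        "# Mean of per-batch accuracies, as computed in train_step() and test_step()\n",
        "per_batch_acc = sum(correct / size for correct, size in batches) / len(batches)\n",
        "\n",
        "# Accuracy over all samples, weighting every sample equally\n",
        "per_sample_acc = sum(c for c, _ in batches) / sum(s for _, s in batches)\n",
        "\n",
        "print(f\"per-batch mean: {per_batch_acc:.4f} | per-sample: {per_sample_acc:.4f}\")\n",
        "```\n",
        "\n",
        "For the small dataset used here the gap is usually minor, so the per-batch average is a reasonable simplification."
      ]
    },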
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XkQJke3S-npO"
      },
      "source": [
        "And we'll combine `train_step()` and `test_step()` into `train()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "id": "-oze8b6icWE7"
      },
      "outputs": [],
      "source": [
        "from typing import Dict, List\n",
        "\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "def train(model: torch.nn.Module, \n",
        "          train_dataloader: torch.utils.data.DataLoader, \n",
        "          test_dataloader: torch.utils.data.DataLoader, \n",
        "          optimizer: torch.optim.Optimizer,\n",
        "          loss_fn: torch.nn.Module,\n",
        "          epochs: int,\n",
        "          device: torch.device) -> Dict[str, List[float]]:\n",
        "  \"\"\"Trains and tests a PyTorch model.\n",
        "\n",
        "  Passes a target PyTorch models through train_step() and test_step()\n",
        "  functions for a number of epochs, training and testing the model\n",
        "  in the same epoch loop.\n",
        "\n",
        "  Calculates, prints and stores evaluation metrics throughout.\n",
        "\n",
        "  Args:\n",
        "    model: A PyTorch model to be trained and tested.\n",
        "    train_dataloader: A DataLoader instance for the model to be trained on.\n",
        "    test_dataloader: A DataLoader instance for the model to be tested on.\n",
        "    optimizer: A PyTorch optimizer to help minimize the loss function.\n",
        "    loss_fn: A PyTorch loss function to calculate loss on both datasets.\n",
        "    epochs: An integer indicating how many epochs to train for.\n",
        "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
        "\n",
        "  Returns:\n",
        "    A dictionary of training and testing loss as well as training and\n",
        "    testing accuracy metrics. Each metric has a value in a list for \n",
        "    each epoch.\n",
        "    In the form: {train_loss: [...],\n",
        "                  train_acc: [...],\n",
        "                  test_loss: [...],\n",
        "                  test_acc: [...]} \n",
        "    For example if training for epochs=2: \n",
        "                 {train_loss: [2.0616, 1.0537],\n",
        "                  train_acc: [0.3945, 0.3945],\n",
        "                  test_loss: [1.2641, 1.5706],\n",
        "                  test_acc: [0.3400, 0.2973]} \n",
        "  \"\"\"\n",
        "  # Create empty results dictionary\n",
        "  results = {\"train_loss\": [],\n",
        "      \"train_acc\": [],\n",
        "      \"test_loss\": [],\n",
        "      \"test_acc\": []\n",
        "  }\n",
        "  \n",
        "  # Loop through training and testing steps for a number of epochs\n",
        "  for epoch in tqdm(range(epochs)):\n",
        "      train_loss, train_acc = train_step(model=model,\n",
        "                                          dataloader=train_dataloader,\n",
        "                                          loss_fn=loss_fn,\n",
        "                                          optimizer=optimizer,\n",
        "                                          device=device)\n",
        "      test_loss, test_acc = test_step(model=model,\n",
        "          dataloader=test_dataloader,\n",
        "          loss_fn=loss_fn,\n",
        "          device=device)\n",
        "      \n",
        "      # Print out what's happening\n",
        "      print(\n",
        "          f\"Epoch: {epoch+1} | \"\n",
        "          f\"train_loss: {train_loss:.4f} | \"\n",
        "          f\"train_acc: {train_acc:.4f} | \"\n",
        "          f\"test_loss: {test_loss:.4f} | \"\n",
        "          f\"test_acc: {test_acc:.4f}\"\n",
        "      )\n",
        "\n",
        "      # Update results dictionary\n",
        "      results[\"train_loss\"].append(train_loss)\n",
        "      results[\"train_acc\"].append(train_acc)\n",
        "      results[\"test_loss\"].append(test_loss)\n",
        "      results[\"test_acc\"].append(test_acc)\n",
        "\n",
        "  # Return the filled results at the end of the epochs\n",
        "  return results"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 4.1 Turn training functions into a script (`engine.py`)"
      ],
      "metadata": {
        "id": "zMvRLdXN8E3r"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile going_modular/engine.py\n",
        "\"\"\"\n",
        "Contains functions for training and testing a PyTorch model.\n",
        "\"\"\"\n",
        "from typing import Dict, List, Tuple\n",
        "\n",
        "import torch\n",
        "\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "def train_step(model: torch.nn.Module, \n",
        "               dataloader: torch.utils.data.DataLoader, \n",
        "               loss_fn: torch.nn.Module, \n",
        "               optimizer: torch.optim.Optimizer,\n",
        "               device: torch.device) -> Tuple[float, float]:\n",
        "  \"\"\"Trains a PyTorch model for a single epoch.\n",
        "\n",
        "  Turns a target PyTorch model to training mode and then\n",
        "  runs through all of the required training steps (forward\n",
        "  pass, loss calculation, optimizer step).\n",
        "\n",
        "  Args:\n",
        "    model: A PyTorch model to be trained.\n",
        "    dataloader: A DataLoader instance for the model to be trained on.\n",
        "    loss_fn: A PyTorch loss function to minimize.\n",
        "    optimizer: A PyTorch optimizer to help minimize the loss function.\n",
        "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
        "\n",
        "  Returns:\n",
        "    A tuple of training loss and training accuracy metrics.\n",
        "    In the form (train_loss, train_accuracy). For example:\n",
        "    \n",
        "    (0.1112, 0.8743)\n",
        "  \"\"\"\n",
        "  # Put model in train mode\n",
        "  model.train()\n",
        "  \n",
        "  # Setup train loss and train accuracy values\n",
        "  train_loss, train_acc = 0, 0\n",
        "  \n",
        "  # Loop through data loader data batches\n",
        "  for batch, (X, y) in enumerate(dataloader):\n",
        "      # Send data to target device\n",
        "      X, y = X.to(device), y.to(device)\n",
        "\n",
        "      # 1. Forward pass\n",
        "      y_pred = model(X)\n",
        "\n",
        "      # 2. Calculate  and accumulate loss\n",
        "      loss = loss_fn(y_pred, y)\n",
        "      train_loss += loss.item() \n",
        "\n",
        "      # 3. Optimizer zero grad\n",
        "      optimizer.zero_grad()\n",
        "\n",
        "      # 4. Loss backward\n",
        "      loss.backward()\n",
        "\n",
        "      # 5. Optimizer step\n",
        "      optimizer.step()\n",
        "\n",
        "      # Calculate and accumulate accuracy metric across all batches\n",
        "      y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)\n",
        "      train_acc += (y_pred_class == y).sum().item()/len(y_pred)\n",
        "\n",
        "  # Adjust metrics to get average loss and accuracy per batch \n",
        "  train_loss = train_loss / len(dataloader)\n",
        "  train_acc = train_acc / len(dataloader)\n",
        "  return train_loss, train_acc\n",
        "\n",
        "def test_step(model: torch.nn.Module, \n",
        "              dataloader: torch.utils.data.DataLoader, \n",
        "              loss_fn: torch.nn.Module,\n",
        "              device: torch.device) -> Tuple[float, float]:\n",
        "  \"\"\"Tests a PyTorch model for a single epoch.\n",
        "\n",
        "  Turns a target PyTorch model to \"eval\" mode and then performs\n",
        "  a forward pass on a testing dataset.\n",
        "\n",
        "  Args:\n",
        "    model: A PyTorch model to be tested.\n",
        "    dataloader: A DataLoader instance for the model to be tested on.\n",
        "    loss_fn: A PyTorch loss function to calculate loss on the test data.\n",
        "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
        "\n",
        "  Returns:\n",
        "    A tuple of testing loss and testing accuracy metrics.\n",
        "    In the form (test_loss, test_accuracy). For example:\n",
        "    \n",
        "    (0.0223, 0.8985)\n",
        "  \"\"\"\n",
        "  # Put model in eval mode\n",
        "  model.eval() \n",
        "  \n",
        "  # Setup test loss and test accuracy values\n",
        "  test_loss, test_acc = 0, 0\n",
        "  \n",
        "  # Turn on inference context manager\n",
        "  with torch.inference_mode():\n",
        "      # Loop through DataLoader batches\n",
        "      for batch, (X, y) in enumerate(dataloader):\n",
        "          # Send data to target device\n",
        "          X, y = X.to(device), y.to(device)\n",
        "  \n",
        "          # 1. Forward pass\n",
        "          test_pred_logits = model(X)\n",
        "\n",
        "          # 2. Calculate and accumulate loss\n",
        "          loss = loss_fn(test_pred_logits, y)\n",
        "          test_loss += loss.item()\n",
        "          \n",
        "          # Calculate and accumulate accuracy\n",
        "          test_pred_labels = test_pred_logits.argmax(dim=1)\n",
        "          test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))\n",
        "          \n",
        "  # Adjust metrics to get average loss and accuracy per batch \n",
        "  test_loss = test_loss / len(dataloader)\n",
        "  test_acc = test_acc / len(dataloader)\n",
        "  return test_loss, test_acc\n",
        "\n",
        "\n",
        "def train(model: torch.nn.Module, \n",
        "          train_dataloader: torch.utils.data.DataLoader, \n",
        "          test_dataloader: torch.utils.data.DataLoader, \n",
        "          optimizer: torch.optim.Optimizer,\n",
        "          loss_fn: torch.nn.Module,\n",
        "          epochs: int,\n",
        "          device: torch.device) -> Dict[str, List[float]]:\n",
        "  \"\"\"Trains and tests a PyTorch model.\n",
        "\n",
        "  Passes a target PyTorch models through train_step() and test_step()\n",
        "  functions for a number of epochs, training and testing the model\n",
        "  in the same epoch loop.\n",
        "\n",
        "  Calculates, prints and stores evaluation metrics throughout.\n",
        "\n",
        "  Args:\n",
        "    model: A PyTorch model to be trained and tested.\n",
        "    train_dataloader: A DataLoader instance for the model to be trained on.\n",
        "    test_dataloader: A DataLoader instance for the model to be tested on.\n",
        "    optimizer: A PyTorch optimizer to help minimize the loss function.\n",
        "    loss_fn: A PyTorch loss function to calculate loss on both datasets.\n",
        "    epochs: An integer indicating how many epochs to train for.\n",
        "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
        "\n",
        "  Returns:\n",
        "    A dictionary of training and testing loss as well as training and\n",
        "    testing accuracy metrics. Each metric has a value in a list for \n",
        "    each epoch.\n",
        "    In the form: {train_loss: [...],\n",
        "                  train_acc: [...],\n",
        "                  test_loss: [...],\n",
        "                  test_acc: [...]} \n",
        "    For example if training for epochs=2: \n",
        "                 {train_loss: [2.0616, 1.0537],\n",
        "                  train_acc: [0.3945, 0.3945],\n",
        "                  test_loss: [1.2641, 1.5706],\n",
        "                  test_acc: [0.3400, 0.2973]} \n",
        "  \"\"\"\n",
        "  # Create empty results dictionary\n",
        "  results = {\"train_loss\": [],\n",
        "      \"train_acc\": [],\n",
        "      \"test_loss\": [],\n",
        "      \"test_acc\": []\n",
        "  }\n",
        "  \n",
        "  # Loop through training and testing steps for a number of epochs\n",
        "  for epoch in tqdm(range(epochs)):\n",
        "      train_loss, train_acc = train_step(model=model,\n",
        "                                          dataloader=train_dataloader,\n",
        "                                          loss_fn=loss_fn,\n",
        "                                          optimizer=optimizer,\n",
        "                                          device=device)\n",
        "      test_loss, test_acc = test_step(model=model,\n",
        "          dataloader=test_dataloader,\n",
        "          loss_fn=loss_fn,\n",
        "          device=device)\n",
        "      \n",
        "      # Print out what's happening\n",
        "      print(\n",
        "          f\"Epoch: {epoch+1} | \"\n",
        "          f\"train_loss: {train_loss:.4f} | \"\n",
        "          f\"train_acc: {train_acc:.4f} | \"\n",
        "          f\"test_loss: {test_loss:.4f} | \"\n",
        "          f\"test_acc: {test_acc:.4f}\"\n",
        "      )\n",
        "\n",
        "      # Update results dictionary\n",
        "      results[\"train_loss\"].append(train_loss)\n",
        "      results[\"train_acc\"].append(train_acc)\n",
        "      results[\"test_loss\"].append(test_loss)\n",
        "      results[\"test_acc\"].append(test_acc)\n",
        "\n",
        "  # Return the filled results at the end of the epochs\n",
        "  return results"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LMTp04qP8P84",
        "outputId": "46df6b1c-0156-4911-cb2b-0f293ca909a6"
      },
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing going_modular/engine.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from going_modular import engine\n",
        "\n",
        "# engine.train()"
      ],
      "metadata": {
        "id": "VxHYlGQY83o2"
      },
      "execution_count": 35,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AftIAK-zWR35"
      },
      "source": [
        "## 5. Creating a function to save the model\n",
        "\n",
        "Let's setup a function to save our model to a directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 36,
      "metadata": {
        "id": "rFRZw6BZWWcd"
      },
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "\n",
        "def save_model(model: torch.nn.Module,\n",
        "               target_dir: str,\n",
        "               model_name: str):\n",
        "  \"\"\"Saves a PyTorch model to a target directory.\n",
        "\n",
        "  Args:\n",
        "    model: A target PyTorch model to save.\n",
        "    target_dir: A directory for saving the model to.\n",
        "    model_name: A filename for the saved model. Should include\n",
        "      either \".pth\" or \".pt\" as the file extension.\n",
        "  \n",
        "  Example usage:\n",
        "    save_model(model=model_0,\n",
        "               target_dir=\"models\",\n",
        "               model_name=\"05_going_modular_tingvgg_model.pth\")\n",
        "  \"\"\"\n",
        "  # Create target directory\n",
        "  target_dir_path = Path(target_dir)\n",
        "  target_dir_path.mkdir(parents=True,\n",
        "                        exist_ok=True)\n",
        "  \n",
        "  # Create model save path\n",
        "  assert model_name.endswith(\".pth\") or model_name.endswith(\".pt\"), \"model_name should end with '.pt' or '.pth'\"\n",
        "  model_save_path = target_dir_path / model_name\n",
        "\n",
        "  # Save the model state_dict()\n",
        "  print(f\"[INFO] Saving model to: {model_save_path}\")\n",
        "  torch.save(obj=model.state_dict(),\n",
        "             f=model_save_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 5.1 Create a file called `utils.py` with utility functions\n",
        "\n",
        "\"utils\" in Python is generally reserved for various utility functions.\n",
        "\n",
        "Right now we only have one utility function (`save_model()`) but... as our code grows we'll likely have more..."
      ],
      "metadata": {
        "id": "Es3tQtQY-BHc"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile going_modular/utils.py\n",
        "\"\"\"\n",
        "File containing various utility functions for PyTorch model training.\n",
        "\"\"\" \n",
        "import torch\n",
        "\n",
        "from pathlib import Path\n",
        "\n",
        "def save_model(model: torch.nn.Module,\n",
        "               target_dir: str,\n",
        "               model_name: str):\n",
        "  \"\"\"Saves a PyTorch model to a target directory.\n",
        "\n",
        "  Args:\n",
        "    model: A target PyTorch model to save.\n",
        "    target_dir: A directory for saving the model to.\n",
        "    model_name: A filename for the saved model. Should include\n",
        "      either \".pth\" or \".pt\" as the file extension.\n",
        "  \n",
        "  Example usage:\n",
        "    save_model(model=model_0,\n",
        "               target_dir=\"models\",\n",
        "               model_name=\"05_going_modular_tingvgg_model.pth\")\n",
        "  \"\"\"\n",
        "  # Create target directory\n",
        "  target_dir_path = Path(target_dir)\n",
        "  target_dir_path.mkdir(parents=True,\n",
        "                        exist_ok=True)\n",
        "  \n",
        "  # Create model save path\n",
        "  assert model_name.endswith(\".pth\") or model_name.endswith(\".pt\"), \"model_name should end with '.pt' or '.pth'\"\n",
        "  model_save_path = target_dir_path / model_name\n",
        "\n",
        "  # Save the model state_dict()\n",
        "  print(f\"[INFO] Saving model to: {model_save_path}\")\n",
        "  torch.save(obj=model.state_dict(),\n",
        "             f=model_save_path)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "kOUx-PNW-PVy",
        "outputId": "24dfe175-8f8a-4bba-b7ad-4117a9383126"
      },
      "execution_count": 38,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Overwriting going_modular/utils.py\n"
          ]
        }
      ]
    },
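    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`save_model()` checks the file extension with `assert`, which Python skips when it is run with the `-O` flag. One possible alternative, sketched here with a hypothetical helper `build_model_save_path()` that is not part of `utils.py`, raises a `ValueError` instead:\n",
        "\n",
        "```python\n",
        "from pathlib import Path\n",
        "\n",
        "def build_model_save_path(target_dir: str, model_name: str) -> Path:\n",
        "    # Hypothetical helper: validate the extension, then join directory and filename\n",
        "    if Path(model_name).suffix not in {\".pt\", \".pth\"}:\n",
        "        raise ValueError(\"model_name should end with '.pt' or '.pth'\")\n",
        "    return Path(target_dir) / model_name\n",
        "\n",
        "print(build_model_save_path(\"models\", \"05_going_modular_cell_mode_tinyvgg_model.pth\"))\n",
        "```\n",
        "\n",
        "The returned path could then be passed as `f=` to `torch.save()`, just as `model_save_path` is in `save_model()`."
      ]
    },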
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "20ZO64U2cZY_"
      },
      "source": [
        "## 6. Train, evaluate and save the model\n",
        "\n",
        "Let's leverage the functions we've got above to train, test and save a model to file.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 173,
          "referenced_widgets": [
            "ac11b3a31c254a448864f2bfce323e57",
            "c5753d59c679446e92f726809bb5b1f3",
            "a5657c5ea7ca496585deefeca5b62e07",
            "69ad063f23594ab1ade5d715fb31f003",
            "92d2ea5f5af14158acb8de92670f277d",
            "d9b7af5921494043a0dbb690ccf04ce0",
            "e5103d2ef47944c7a2f91e3abb5694b1",
            "8ea219a99b5e44c68d28fc43b7a2c165",
            "c761fbfe64da43b3983fd7a50d10f4a3",
            "633797dc11bc42fc958082052d139b08",
            "bd5072dc5ed242a68dd15e3742deee7a"
          ]
        },
        "id": "FjeRMiLjccVc",
        "outputId": "dd0d1232-155a-4e83-9298-c161a789ef8e"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/5 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ac11b3a31c254a448864f2bfce323e57"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 1 | train_loss: 1.0962 | train_acc: 0.3778 | test_loss: 1.0684 | test_acc: 0.4133\n",
            "Epoch: 2 | train_loss: 1.0254 | train_acc: 0.5111 | test_loss: 1.0144 | test_acc: 0.4400\n",
            "Epoch: 3 | train_loss: 0.9479 | train_acc: 0.5333 | test_loss: 0.9846 | test_acc: 0.4800\n",
            "Epoch: 4 | train_loss: 0.9026 | train_acc: 0.5778 | test_loss: 0.9910 | test_acc: 0.4400\n",
            "Epoch: 5 | train_loss: 0.8758 | train_acc: 0.6178 | test_loss: 1.0112 | test_acc: 0.5333\n",
            "[INFO] Total training time: 13.737 seconds\n",
            "[INFO] Saving model to: models/05_going_modular_cell_mode_tinyvgg_model.pth\n"
          ]
        }
      ],
      "source": [
        "# Set random seeds\n",
        "torch.manual_seed(42) \n",
        "torch.cuda.manual_seed(42)\n",
        "\n",
        "# Set number of epochs\n",
        "NUM_EPOCHS = 5\n",
        "\n",
        "# Recreate an instance of TinyVGG\n",
        "model_0 = TinyVGG(input_shape=3, # number of color channels (3 for RGB) \n",
        "                  hidden_units=10, \n",
        "                  output_shape=len(train_data.classes)).to(device)\n",
        "\n",
        "# Setup loss function and optimizer\n",
        "loss_fn = nn.CrossEntropyLoss()\n",
        "optimizer = torch.optim.Adam(params=model_0.parameters(), lr=0.001)\n",
        "\n",
        "# Start the timer\n",
        "from timeit import default_timer as timer \n",
        "start_time = timer()\n",
        "\n",
        "# Train model_0 \n",
        "model_0_results = train(model=model_0, \n",
        "                        train_dataloader=train_dataloader,\n",
        "                        test_dataloader=test_dataloader,\n",
        "                        optimizer=optimizer,\n",
        "                        loss_fn=loss_fn, \n",
        "                        epochs=NUM_EPOCHS,\n",
        "                        device=device)\n",
        "\n",
        "# End the timer and print out how long it took\n",
        "end_time = timer()\n",
        "print(f\"[INFO] Total training time: {end_time-start_time:.3f} seconds\")\n",
        "\n",
        "# Save the model\n",
        "save_model(model=model_0,\n",
        "           target_dir=\"models\",\n",
        "           model_name=\"05_going_modular_cell_mode_tinyvgg_model.pth\")"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 6.1 Train, evaluate and save the model (script mode) -> `train.py`\n",
        "\n",
        "Let's create a file called `train.py` that uses our other scripts (`data_setup.py`, `engine.py`, `model_builder.py` and `utils.py`) to train a PyTorch model.\n",
        "\n",
        "Essentially, we want to replicate the functionality of notebook 04 with a single command: `python going_modular/train.py`."
      ],
      "metadata": {
        "id": "igk_aKDS_w5Y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile going_modular/train.py\n",
        "\"\"\"\n",
        "Trains a PyTorch image classification model using device-agnostic code.\n",
        "\"\"\"\n",
        "import torch\n",
        "\n",
        "from torchvision import transforms\n",
        "from timeit import default_timer as timer\n",
        "\n",
        "import data_setup, engine, model_builder, utils\n",
        "\n",
        "# Setup hyperparameters\n",
        "NUM_EPOCHS = 5\n",
        "BATCH_SIZE = 32\n",
        "HIDDEN_UNITS = 10\n",
        "LEARNING_RATE = 0.001\n",
        "\n",
        "# Setup directories\n",
        "train_dir = \"data/pizza_steak_sushi/train\"\n",
        "test_dir = \"data/pizza_steak_sushi/test\"\n",
        "\n",
        "# Setup device-agnostic code\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Create transforms\n",
        "data_transform = transforms.Compose([\n",
        "    transforms.Resize((64, 64)),\n",
        "    transforms.ToTensor()\n",
        "])\n",
        "\n",
        "# Create DataLoaders and get class_names\n",
        "train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(\n",
        "    train_dir=train_dir,\n",
        "    test_dir=test_dir,\n",
        "    transform=data_transform,\n",
        "    batch_size=BATCH_SIZE\n",
        ")\n",
        "\n",
        "# Create model\n",
        "model = model_builder.TinyVGG(input_shape=3,\n",
        "                              hidden_units=HIDDEN_UNITS,\n",
        "                              output_shape=len(class_names)).to(device)\n",
        "\n",
        "# Setup loss and optimizer\n",
        "loss_fn = torch.nn.CrossEntropyLoss()\n",
        "optimizer = torch.optim.Adam(model.parameters(),\n",
        "                             lr=LEARNING_RATE)\n",
        "\n",
        "# Start the timer\n",
        "start_time = timer()\n",
        "\n",
        "# Start training with help from engine.py\n",
        "engine.train(model=model,\n",
        "             train_dataloader=train_dataloader,\n",
        "             test_dataloader=test_dataloader,\n",
        "             loss_fn=loss_fn,\n",
        "             optimizer=optimizer,\n",
        "             epochs=NUM_EPOCHS, \n",
        "             device=device)\n",
        "\n",
        "# End the timer and print out how long it took\n",
        "end_time = timer()\n",
        "print(f\"[INFO] Total training time: {end_time-start_time:.3f} seconds\")\n",
        "\n",
        "# Save the model to file\n",
        "utils.save_model(model=model,\n",
        "                 target_dir=\"models\",\n",
        "                 model_name=\"05_going_modular_script_mode_tinyvgg_model.pth\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "T8uOD8DO__pL",
        "outputId": "0f3dc5f7-ce79-4f6e-fe66-00aef6efd30d"
      },
      "execution_count": 40,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Overwriting going_modular/train.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!python going_modular/train.py"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AlTKB8Xk-wIZ",
        "outputId": "60029b0f-4c9b-411a-912f-c8c93be42795"
      },
      "execution_count": 41,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "  0% 0/5 [00:00<?, ?it/s]Epoch: 1 | train_loss: 1.1019 | train_acc: 0.3203 | test_loss: 1.1089 | test_acc: 0.1979\n",
            " 20% 1/5 [00:01<00:06,  1.51s/it]Epoch: 2 | train_loss: 1.1055 | train_acc: 0.2930 | test_loss: 1.1385 | test_acc: 0.1979\n",
            " 40% 2/5 [00:02<00:04,  1.47s/it]Epoch: 3 | train_loss: 1.1128 | train_acc: 0.2930 | test_loss: 1.1272 | test_acc: 0.1979\n",
            " 60% 3/5 [00:04<00:02,  1.45s/it]Epoch: 4 | train_loss: 1.0973 | train_acc: 0.3906 | test_loss: 1.0982 | test_acc: 0.3021\n",
            " 80% 4/5 [00:05<00:01,  1.45s/it]Epoch: 5 | train_loss: 1.0959 | train_acc: 0.4141 | test_loss: 1.1004 | test_acc: 0.3125\n",
            "100% 5/5 [00:07<00:00,  1.43s/it]\n",
            "[INFO] Total training time: 7.133 seconds\n",
            "[INFO] Saving model to: models/05_going_modular_script_mode_tinyvgg_model.pth\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z6ox2JCLkKZw"
      },
      "source": [
        "We finish with two saved image classification models: the cell mode model at `models/05_going_modular_cell_mode_tinyvgg_model.pth` and the script mode model at `models/05_going_modular_script_mode_tinyvgg_model.pth`.\n",
        "\n",
        "See the exercises and extra-curriculum in [section 05 of the Learn PyTorch book](https://www.learnpytorch.io/05_pytorch_going_modular/#exercises)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "05_pytorch_going_modular_cell_mode_video.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.7"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "ac11b3a31c254a448864f2bfce323e57": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_c5753d59c679446e92f726809bb5b1f3",
              "IPY_MODEL_a5657c5ea7ca496585deefeca5b62e07",
              "IPY_MODEL_69ad063f23594ab1ade5d715fb31f003"
            ],
            "layout": "IPY_MODEL_92d2ea5f5af14158acb8de92670f277d"
          }
        },
        "c5753d59c679446e92f726809bb5b1f3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d9b7af5921494043a0dbb690ccf04ce0",
            "placeholder": "​",
            "style": "IPY_MODEL_e5103d2ef47944c7a2f91e3abb5694b1",
            "value": "100%"
          }
        },
        "a5657c5ea7ca496585deefeca5b62e07": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8ea219a99b5e44c68d28fc43b7a2c165",
            "max": 5,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_c761fbfe64da43b3983fd7a50d10f4a3",
            "value": 5
          }
        },
        "69ad063f23594ab1ade5d715fb31f003": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_633797dc11bc42fc958082052d139b08",
            "placeholder": "​",
            "style": "IPY_MODEL_bd5072dc5ed242a68dd15e3742deee7a",
            "value": " 5/5 [00:13&lt;00:00,  2.67s/it]"
          }
        },
        "92d2ea5f5af14158acb8de92670f277d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d9b7af5921494043a0dbb690ccf04ce0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e5103d2ef47944c7a2f91e3abb5694b1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8ea219a99b5e44c68d28fc43b7a2c165": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c761fbfe64da43b3983fd7a50d10f4a3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "633797dc11bc42fc958082052d139b08": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "bd5072dc5ed242a68dd15e3742deee7a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}